import torch


def get_optimizer_closure(model, loss_name="cross_entropy",
                          return_output=False):
    """Build a two-argument closure bound to ``model`` and the loss settings.

    The returned callable takes ``(data, target)`` and forwards them,
    together with the captured ``model``, ``loss_name`` and
    ``return_output``, to ``general_closure``; its result is returned as is.
    """
    def closure(data, target):
        return general_closure(
            model,
            data,
            target,
            loss_name,
            return_output,
        )

    return closure


def general_closure(model, data, target, loss_name, return_output=False):
    """Run one forward/backward pass on ``model`` and return the loss.

    Clears the model's gradients, evaluates ``model(data)``, computes the
    loss selected by ``loss_name`` between that output and ``target``, and
    back-propagates it.

    Args:
        model: Callable model exposing ``zero_grad()``.
        data: Input passed to ``model``.
        target: Target passed to the loss together with the model output.
        loss_name: Identifier of the loss; only ``"cross_entropy"`` is
            supported.
        return_output: When true, the model output is returned as well.

    Returns:
        The loss, or the tuple ``(loss, output)`` when ``return_output`` is
        true.

    Raises:
        ValueError: If ``loss_name`` is not a supported loss.
    """
    # Validate before doing any work, and with a real exception: an ``assert``
    # is stripped under ``python -O``, which would leave ``loss`` unbound and
    # surface later as an unrelated UnboundLocalError.
    if loss_name != "cross_entropy":
        raise ValueError(
            f"Unsupported loss_name {loss_name!r}; expected 'cross_entropy'"
        )

    model.zero_grad()
    output = model(data)
    loss = torch.nn.CrossEntropyLoss()(output, target)
    loss.backward()
    if return_output:
        return loss, output
    return loss
